package protoutils

import (
	"reflect"
	"testing"

	"github.com/stackrox/rox/generated/test"
	"github.com/stackrox/rox/pkg/protoassert"
	"github.com/stackrox/rox/pkg/protocompat"
	"github.com/stackrox/rox/pkg/protoreflect"
	"github.com/stretchr/testify/assert"
	"google.golang.org/protobuf/types/known/anypb"
)

// getFilledStruct builds a *test.TestClone in which every field is populated:
// repeated scalars, repeated and map-valued sub-messages, a string map, an enum
// slice, a timestamp taken at call time, and an Any. Every call allocates new
// values, so callers may mutate the result without affecting other calls.
func getFilledStruct() *test.TestClone {
	// Each invocation yields a distinct sub-message, so slice and map entries
	// never share storage with one another.
	newSub := func(n int32, s string) *test.TestCloneSubMessage {
		return &test.TestCloneSubMessage{Int32: n, String_: s}
	}

	return &test.TestClone{
		IntSlice:    []int32{1, 2, 3},
		StringSlice: []string{"1", "2", "3"},
		SubMessages: []*test.TestCloneSubMessage{newSub(1, "1"), newSub(2, "2")},
		MessageMap: map[string]*test.TestCloneSubMessage{
			"1": newSub(1, "1"),
			"2": newSub(2, "2"),
		},
		StringMap: map[string]string{"1": "1a", "2": "2a"},
		EnumSlice: []test.TestClone_CloneEnum{test.TestClone_UNSET, test.TestClone_Val2},
		Ts:        protocompat.TimestampNow(),
		Any:       &anypb.Any{TypeUrl: "type url", Value: []byte("123")},
	}
}

// TestAutogeneratedClone verifies two things about CloneVT. First, it produces
// a message equal to its source. Second, the copy is deep: after cloning,
// mutating a repeated, map, or nested-message field of the source must make
// source and clone compare unequal.
func TestAutogeneratedClone(t *testing.T) {
	// A message with every field unset must round-trip through CloneVT.
	empty := &test.TestClone{}
	assert.True(t, empty.EqualVT(empty.CloneVT()))

	filled := getFilledStruct()
	protoassert.Equal(t, filled, filled.CloneVT())

	// Each mutation touches storage that a shallow copy would share with the
	// clone, so a clone that aliases it would still compare equal.
	mutations := []struct {
		name   string
		mutate func(m *test.TestClone)
	}{
		{name: "IntSlice", mutate: func(m *test.TestClone) { m.IntSlice[0] = 100 }},
		{name: "StringSlice", mutate: func(m *test.TestClone) { m.StringSlice[0] = "100" }},
		{name: "SubMessages", mutate: func(m *test.TestClone) { m.SubMessages[0].Int32 = 100 }},
		{name: "MessageMap", mutate: func(m *test.TestClone) { delete(m.GetMessageMap(), "1") }},
		{name: "StringMap", mutate: func(m *test.TestClone) { delete(m.GetStringMap(), "1") }},
		{name: "EnumSlice", mutate: func(m *test.TestClone) { m.EnumSlice[0] = test.TestClone_Val1 }},
		{name: "Ts", mutate: func(m *test.TestClone) { m.Ts.Seconds = 100000 }},
	}
	for _, tc := range mutations {
		t.Run(tc.name, func(t *testing.T) {
			val := getFilledStruct()
			cloned := val.CloneVT()
			tc.mutate(val)
			assert.False(t, val.EqualVT(cloned))
		})
	}
}

// TestAutogeneratedCloneOneOfs verifies that CloneVT copies every variant of
// the Primitive oneof. It also checks that the message-typed variant is deep
// copied rather than shared with the source.
func TestAutogeneratedCloneOneOfs(t *testing.T) {
	// newMsgVariant returns a fresh message whose oneof holds a sub-message.
	newMsgVariant := func() *test.TestClone {
		return &test.TestClone{
			Primitive: &test.TestClone_Msg{
				Msg: &test.TestCloneSubMessage{
					Int32:   10,
					String_: "10",
				},
			},
		}
	}

	cases := []struct {
		name string
		val  *test.TestClone
	}{
		{name: "int32", val: &test.TestClone{Primitive: &test.TestClone_Int32{Int32: 10}}},
		{name: "string", val: &test.TestClone{Primitive: &test.TestClone_String_{String_: "10"}}},
		{name: "message", val: newMsgVariant()},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.True(t, tc.val.EqualVT(tc.val.CloneVT()))
		})
	}

	// Changing the source's oneof sub-message after cloning must not be
	// visible through the clone. A shallow copy of the oneof wrapper would
	// share the sub-message pointer and still compare equal.
	t.Run("message deep copy", func(t *testing.T) {
		val := newMsgVariant()
		cloned := val.CloneVT()
		val.GetMsg().Int32 = 11
		assert.False(t, val.EqualVT(cloned))
	})
}

// checkPointers fails the test if orig and cloned report the same
// reflect.Value.Pointer address. Both values must be of a kind that
// Value.Pointer accepts (pointer, map, slice, unsafe pointer, ...). When both
// are nil (address 0) nothing is shared, and the check is skipped.
func checkPointers(t *testing.T, orig, cloned reflect.Value) {
	t.Helper()
	origPtr, clonedPtr := orig.Pointer(), cloned.Pointer()
	if origPtr == 0 && clonedPtr == 0 {
		return
	}
	assert.NotEqual(t, origPtr, clonedPtr, "clone shares memory with the original (%s)", orig.Type())
}

func checkAliasRecursive(t *testing.T, orig, cloned reflect.Value) {
	switch orig.Kind() {
	case reflect.Array, reflect.Slice:
		checkPointers(t, orig, cloned)
		for i := 0; i < orig.Len(); i++ {
			checkAliasRecursive(t, orig.Index(i), cloned.Index(i))
		}
	case reflect.Interface:
		checkAliasRecursive(t, orig.Elem(), cloned.Elem())
	case reflect.Map:
		iter := orig.MapRange()
		for iter.Next() {
			checkAliasRecursive(t, iter.Value(), cloned.MapIndex(iter.Key()))
		}
	case reflect.Ptr:
		checkPointers(t, orig, cloned)
		checkAliasRecursive(t, orig.Elem(), cloned.Elem())
	case reflect.Struct:
		for i := 0; i < orig.NumField(); i++ {
			if protoreflect.IsInternalGeneratorField(orig.Type().Field(i)) {
				continue
			}
			checkAliasRecursive(t, orig.Field(i), cloned.Field(i))
		}
	case reflect.UnsafePointer:
		checkPointers(t, orig, cloned)
	}
}

// TestCheckAliasing clones a fully populated message and uses
// checkAliasRecursive to compare the source and the clone for shared references.
func TestCheckAliasing(t *testing.T) {
	source := getFilledStruct()
	dup := source.CloneVT()

	srcVal, dupVal := reflect.ValueOf(source), reflect.ValueOf(dup)
	checkAliasRecursive(t, srcVal, dupVal)
}
